package com.dny.asmtop.op;

import com.dny.asmtop.ASMMethodUtils;
import com.dny.asmtop.Command;
import com.dny.asmtop.MethodContext;
import jdk.internal.org.objectweb.asm.Type;
import jdk.internal.org.objectweb.asm.commons.GeneratorAdapter;

import java.util.Objects;

import static jdk.internal.org.objectweb.asm.Opcodes.*;
import static jdk.internal.org.objectweb.asm.Type.getType;

/**
 * Binary arithmetic command {@code left <op> right} supporting add, subtract, multiply, divide
 * and remainder.
 *
 * <p>Operands that are wrapper types are unboxed, then both are widened to a common primitive
 * type (see {@link #unifyArithmeticTypes(Class[])}); the operation is performed in, and produces,
 * that common type.
 *
 * <p>Created by jlutt on 2018-01-17.
 *
 * @author jlutt
 */
public class CommandArithmetic implements Command {

  /** The arithmetic operator to apply. */
  private final Operation op;

  /** Command producing the left operand; its bytecode is emitted first. */
  private final Command left;

  /** Command producing the right operand; its bytecode is emitted second. */
  private final Command right;

  /** Supported arithmetic operators. */
  public enum Operation {
    ADD(IADD, "+"), SUB(ISUB, "-"), MUL(IMUL, "*"), DIV(IDIV, "/"), REM(IREM, "%");

    /**
     * The {@code int} variant of the opcode; {@link Type#getOpcode(int)} adapts it to the actual
     * operand type (for example {@code IADD} becomes {@code LADD} for {@code long}).
     */
    private final int opCode;

    /** Source-level operator symbol, used by {@link CommandArithmetic#toString()}. */
    private final String symbol;

    Operation(int opCode, String symbol) {
      this.opCode = opCode;
      this.symbol = symbol;
    }
  }

  /**
   * Creates an arithmetic command.
   *
   * @param op the operator to apply
   * @param left command producing the left operand
   * @param right command producing the right operand
   * @throws NullPointerException if any argument is {@code null}
   */
  public CommandArithmetic(Operation op, Command left, Command right) {
    this.op = Objects.requireNonNull(op, "op");
    this.left = Objects.requireNonNull(left, "left");
    this.right = Objects.requireNonNull(right, "right");
  }

  /**
   * Computes the common primitive type used to perform arithmetic over the given operand types.
   *
   * <p>{@code byte}, {@code short}, {@code char} and {@code int} all map to {@code int}. Otherwise
   * the widest type present is chosen, ranked {@code int}, {@code long}, {@code float},
   * {@code double} from narrowest to widest.
   *
   * @param dataTypes primitive operand types
   * @return the unified primitive type, or {@code null} if no types are given
   * @throws IllegalArgumentException if any type is not a numeric primitive (this includes
   *     {@code boolean}, wrapper classes and {@code null})
   */
  public static Class<?> unifyArithmeticTypes(Class<?>... dataTypes) {
    Class<?> resultType = null;
    int resultOrder = 0;

    for (Class<?> dataType : dataTypes) {
      Class<?> t;
      int order;
      if (dataType == Byte.TYPE
          || dataType == Short.TYPE
          || dataType == Character.TYPE
          || dataType == Integer.TYPE) {
        t = Integer.TYPE;
        order = 1;
      } else if (dataType == Long.TYPE) {
        t = Long.TYPE;
        order = 2;
      } else if (dataType == Float.TYPE) {
        t = Float.TYPE;
        order = 3;
      } else if (dataType == Double.TYPE) {
        t = Double.TYPE;
        order = 4;
      } else {
        throw new IllegalArgumentException("Not a numeric primitive type: " + dataType);
      }
      // Keep the widest type seen so far.
      if (resultType == null || order > resultOrder) {
        resultType = t;
        resultOrder = order;
      }
    }

    return resultType;
  }

  /**
   * Returns the type produced by this operation, without emitting any bytecode.
   *
   * @param context the method-generation context
   * @return the unified primitive type of both operands after unboxing
   * @throws IllegalArgumentException if either operand is not numeric
   */
  @Override
  public Type type(MethodContext context) {
    Type leftType = unboxedType(left.type(context));
    Type rightType = unboxedType(right.type(context));
    return getType(unifyArithmeticTypes(
        ASMMethodUtils.getJavaType(leftType), ASMMethodUtils.getJavaType(rightType)));
  }

  /**
   * Emits bytecode that evaluates {@code left}, then {@code right}, and applies the operator.
   *
   * <p>Each operand is converted on the operand stack immediately after it is pushed: it is
   * unboxed if it is a wrapper, then cast to the result type. As a result each operand's bytecode
   * is emitted exactly once, in left-to-right order, and no temporary locals are allocated.
   *
   * @param context the method-generation context
   * @return the result type, as computed by {@link #type(MethodContext)}
   */
  @Override
  public Type generator(MethodContext context) {
    GeneratorAdapter g = context.getGeneratorAdapter();
    Type resultType = type(context);
    pushOperand(context, g, left, resultType);
    pushOperand(context, g, right, resultType);
    g.visitInsn(resultType.getOpcode(op.opCode));
    return resultType;
  }

  /** Maps a wrapper type to its primitive counterpart; returns any other type unchanged. */
  private static Type unboxedType(Type type) {
    return ASMMethodUtils.isWrapperType(type) ? ASMMethodUtils.unwrap(type) : type;
  }

  /** Emits {@code operand}'s bytecode and converts the pushed value to {@code targetType}. */
  private static void pushOperand(
      MethodContext context, GeneratorAdapter g, Command operand, Type targetType) {
    Type operandType = operand.type(context);
    operand.generator(context);
    if (ASMMethodUtils.isWrapperType(operandType)) {
      operandType = ASMMethodUtils.unwrap(operandType);
      g.unbox(operandType);
    }
    // Compare by value: Type overrides equals(), so identity comparison is not reliable.
    if (!operandType.equals(targetType)) {
      g.cast(operandType, targetType);
    }
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    CommandArithmetic that = (CommandArithmetic) o;
    return op == that.op
        && Objects.equals(left, that.left)
        && Objects.equals(right, that.right);
  }

  @Override
  public int hashCode() {
    return Objects.hash(op, left, right);
  }

  /** Returns a debug representation such as {@code (a + b)}. */
  @Override
  public String toString() {
    return "(" + left + " " + op.symbol + " " + right + ")";
  }
}
